package com.example.security;

import com.example.util.SecurityUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/**
 * Comprehensive security filter that combines several request-level security checks.
 *
 * <p>Checks run in this order, stopping at the first failure:
 * <ol>
 *   <li>basic checks: allowed HTTP method, User-Agent present and not on the scanner
 *       blocklist, request URI length;</li>
 *   <li>input validation: URI, query string and non-whitelisted headers via
 *       {@link SecurityUtils#isSecureInput};</li>
 *   <li>session-integrity validation (authenticated, non-public requests with a Bearer token);</li>
 *   <li>API-abuse detection (authenticated, non-public requests).</li>
 * </ol>
 * A rejected request gets a JSON error body and is not forwarded. A request that passes
 * receives a set of security response headers and continues down the chain.
 *
 * <p>Only exceptions raised by the checks themselves become a {@code SECURITY_ERROR}
 * response. Exceptions from further down the chain propagate unchanged.
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class ComprehensiveSecurityFilter implements Filter {

    /** Longest request URI (in characters) accepted before rejecting with URI_TOO_LONG. */
    private static final int MAX_URI_LENGTH = 2048;

    /** Maximum number of characters of an untrusted value copied into a single log line. */
    private static final int MAX_LOGGED_VALUE_LENGTH = 256;

    /** Accepted HTTP methods, matched case-sensitively (same as the previous regex). */
    private static final Set<String> ALLOWED_METHODS =
            Set.of("GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD");

    /** Lower-case substrings that mark a User-Agent as a known scanning/attack tool. */
    private static final List<String> MALICIOUS_USER_AGENT_PATTERNS = List.of(
            "sqlmap", "nmap", "nikto", "dirb", "gobuster", "masscan",
            "nessus", "openvas", "burp", "zap", "w3af", "skipfish");

    /** Lower-case header names that validateHeaders does not inspect. */
    private static final Set<String> COMMON_SAFE_HEADERS = Set.of(
            "user-agent", "accept", "accept-language", "accept-encoding",
            "connection", "host", "referer", "origin", "sec-fetch-site",
            "sec-fetch-mode", "sec-fetch-dest", "sec-ch-ua", "sec-ch-ua-mobile",
            "sec-ch-ua-platform", "cache-control", "pragma", "upgrade-insecure-requests",
            "content-type", "content-length", "authorization", "cookie",
            "x-requested-with", "x-forwarded-for", "x-real-ip");

    /** Request-URI prefixes that skip the session and API-abuse checks. */
    private static final List<String> PUBLIC_ENDPOINT_PREFIXES = List.of(
            "/api/auth/", "/api/public/", "/api/payments/callback/",
            "/ws", "/actuator/", "/swagger-ui/", "/v3/api-docs",
            "/error", "/favicon.ico", "/h2-console/", "/static/",
            "/css/", "/js/", "/images/", "/fonts/");

    private static final String BEARER_PREFIX = "Bearer ";

    private final SecurityUtils securityUtils;
    private final SessionHijackingProtection sessionProtection;
    private final ApiAbuseProtection apiAbuseProtection;
    private final ObjectMapper objectMapper;

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {

        // Only HTTP traffic is inspected; anything else is passed through rather than failing
        // with a ClassCastException.
        if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) {
            chain.doFilter(request, response);
            return;
        }
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        HttpServletResponse httpResponse = (HttpServletResponse) response;

        boolean allowed;
        try {
            allowed = runSecurityChecks(httpRequest, httpResponse);
        } catch (Exception e) {
            // Fail closed, but only for failures of the checks themselves.
            log.error("综合安全过滤器异常: {}", e.getMessage(), e);
            sendErrorResponse(httpResponse, "SECURITY_ERROR", "安全检查异常");
            return;
        }

        if (!allowed) {
            // The rejecting check has already written the error response.
            return;
        }

        // 5. Add security response headers.
        addSecurityHeaders(httpResponse);

        // Deliberately outside the try block: exceptions from downstream filters/controllers
        // must reach the container's error handling instead of being rewritten into a
        // SECURITY_ERROR (which would also fail once the response has been committed).
        chain.doFilter(request, response);
    }

    /**
     * Runs checks 1-4 in order, short-circuiting on the first rejection.
     *
     * @return {@code true} if every check passed; {@code false} if one rejected the request,
     *         in which case that check has already written the error response
     */
    private boolean runSecurityChecks(HttpServletRequest request, HttpServletResponse response)
            throws IOException {
        return performBasicSecurityChecks(request, response)   // 1. basic checks
                && validateInput(request, response)             // 2. input validation
                && validateSession(request, response)           // 3. session security
                && checkApiAbuse(request, response);            // 4. API abuse
    }

    /**
     * Basic checks: allowed HTTP method, non-blank and non-blocklisted User-Agent, URI length.
     *
     * @return {@code false} (after writing an error response) if the request is rejected
     */
    private boolean performBasicSecurityChecks(HttpServletRequest request, HttpServletResponse response)
            throws IOException {

        String requestUri = request.getRequestURI();
        String method = request.getMethod();
        String userAgent = request.getHeader("User-Agent");

        if (!isValidHttpMethod(method)) {
            log.warn("无效的HTTP方法: method={}, uri={}", sanitizeForLog(method), sanitizeForLog(requestUri));
            sendErrorResponse(response, "INVALID_METHOD", "无效的HTTP方法");
            return false;
        }

        if (userAgent == null || userAgent.trim().isEmpty()) {
            log.warn("缺少User-Agent: uri={}", sanitizeForLog(requestUri));
            sendErrorResponse(response, "MISSING_USER_AGENT", "缺少User-Agent");
            return false;
        }

        if (isMaliciousUserAgent(userAgent)) {
            log.warn("恶意User-Agent: userAgent={}, uri={}", sanitizeForLog(userAgent), sanitizeForLog(requestUri));
            sendErrorResponse(response, "MALICIOUS_USER_AGENT", "恶意User-Agent");
            return false;
        }

        if (requestUri.length() > MAX_URI_LENGTH) {
            log.warn("请求URI过长: length={}, uri={}", requestUri.length(), sanitizeForLog(requestUri));
            sendErrorResponse(response, "URI_TOO_LONG", "请求URI过长");
            return false;
        }

        return true;
    }

    /**
     * Input validation: runs {@link SecurityUtils#isSecureInput} over the raw request URI,
     * the raw query string, and (via validateHeaders) the non-whitelisted headers.
     *
     * @return {@code false} (after writing an error response) if any part is rejected
     */
    private boolean validateInput(HttpServletRequest request, HttpServletResponse response)
            throws IOException {

        String requestUri = request.getRequestURI();
        String queryString = request.getQueryString();

        if (!securityUtils.isSecureInput(requestUri)) {
            log.warn("URI包含恶意内容: uri={}", sanitizeForLog(requestUri));
            sendErrorResponse(response, "MALICIOUS_URI", "请求包含恶意内容");
            return false;
        }

        if (queryString != null && !securityUtils.isSecureInput(queryString)) {
            log.warn("查询参数包含恶意内容: query={}", sanitizeForLog(queryString));
            sendErrorResponse(response, "MALICIOUS_QUERY", "查询参数包含恶意内容");
            return false;
        }

        if (!validateHeaders(request)) {
            log.warn("请求头包含恶意内容: uri={}", sanitizeForLog(requestUri));
            sendErrorResponse(response, "MALICIOUS_HEADERS", "请求头包含恶意内容");
            return false;
        }

        return true;
    }

    /**
     * Session-integrity check for authenticated requests to non-public endpoints that carry
     * a Bearer token; delegates to {@link SessionHijackingProtection#validateSession}.
     *
     * @return {@code false} (after writing an error response) if the session is rejected
     */
    private boolean validateSession(HttpServletRequest request, HttpServletResponse response)
            throws IOException {

        if (isPublicEndpoint(request.getRequestURI())) {
            return true;
        }

        // NOTE(review): with Spring Security's default anonymous support the context holds an
        // AnonymousAuthenticationToken, whose isAuthenticated() is true, so anonymous callers
        // are not skipped here. Confirm whether that is intended.
        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        if (auth == null || !auth.isAuthenticated()) {
            return true; // Unauthenticated requests are left to other filters.
        }

        String token = getTokenFromRequest(request);
        if (token != null && !sessionProtection.validateSession(token, request)) {
            log.warn("会话安全验证失败: user={}", auth.getName());
            sendErrorResponse(response, "SESSION_SECURITY_VIOLATION", "会话安全验证失败");
            return false;
        }

        return true;
    }

    /**
     * API-abuse check for authenticated requests to non-public endpoints, keyed by
     * {@code auth.getName()} and the request URI.
     *
     * @return {@code false} (after writing an error response) if abuse was detected
     */
    private boolean checkApiAbuse(HttpServletRequest request, HttpServletResponse response)
            throws IOException {

        if (isPublicEndpoint(request.getRequestURI())) {
            return true;
        }

        // NOTE(review): anonymous tokens report isAuthenticated() == true, so all anonymous
        // callers could share one abuse bucket under the same name. Confirm intent.
        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        if (auth == null || !auth.isAuthenticated()) {
            return true;
        }

        String userId = auth.getName();
        String apiPath = request.getRequestURI();

        if (!apiAbuseProtection.checkApiAbuse(userId, apiPath, request)) {
            log.warn("API滥用检测: user={}, path={}", userId, sanitizeForLog(apiPath));
            sendErrorResponse(response, "API_ABUSE_DETECTED", "检测到API滥用");
            return false;
        }

        return true;
    }

    /**
     * Adds security-related response headers to a request that passed all checks.
     */
    private void addSecurityHeaders(HttpServletResponse response) {
        // Prevent caching of potentially sensitive responses.
        response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate");
        response.setHeader("Pragma", "no-cache");
        response.setHeader("Expires", "0");

        // Hide server identification.
        response.setHeader("Server", "");

        // Disable MIME-type sniffing.
        response.setHeader("X-Content-Type-Options", "nosniff");

        // NOTE(review): X-XSS-Protection is deprecated and ignored by current browsers; OWASP
        // now recommends "0" plus a Content-Security-Policy. Value kept unchanged here.
        response.setHeader("X-XSS-Protection", "1; mode=block");

        // Prevent clickjacking.
        response.setHeader("X-Frame-Options", "DENY");

        // Restrict powerful browser features.
        response.setHeader("Permissions-Policy",
            "geolocation=(), microphone=(), camera=(), payment=(), usb=()");

        // Referrer policy.
        response.setHeader("Referrer-Policy", "strict-origin-when-cross-origin");
    }

    /**
     * Returns whether {@code method} is one of the allowed HTTP methods (case-sensitive).
     * A {@code null} method is treated as invalid.
     */
    private boolean isValidHttpMethod(String method) {
        return method != null && ALLOWED_METHODS.contains(method);
    }

    /**
     * Returns whether the User-Agent contains any known scanner/attack-tool signature,
     * compared case-insensitively with locale-independent lowercasing.
     */
    private boolean isMaliciousUserAgent(String userAgent) {
        // Locale.ROOT: under e.g. a Turkish default locale "NIKTO" would otherwise lowercase
        // to "nıkto" and slip past the blocklist.
        String normalized = userAgent.toLowerCase(Locale.ROOT);
        return MALICIOUS_USER_AGENT_PATTERNS.stream().anyMatch(normalized::contains);
    }

    /**
     * Validates every non-whitelisted request header name and every one of its values.
     *
     * @return {@code false} as soon as a header name or value is rejected
     */
    private boolean validateHeaders(HttpServletRequest request) {
        java.util.Enumeration<String> headerNames = request.getHeaderNames();

        while (headerNames.hasMoreElements()) {
            String headerName = headerNames.nextElement();

            // Well-known browser/infrastructure headers are not inspected.
            if (isCommonSafeHeader(headerName)) {
                continue;
            }

            if (!securityUtils.isSecureInput(headerName)) {
                return false;
            }

            // Inspect every occurrence of the header: getHeader() only returns the first value,
            // so a payload could otherwise hide in a repeated header.
            java.util.Enumeration<String> headerValues = request.getHeaders(headerName);
            while (headerValues != null && headerValues.hasMoreElements()) {
                String headerValue = headerValues.nextElement();
                if (headerValue != null && !isSecureHeaderValue(headerName, headerValue)) {
                    return false;
                }
            }
        }

        return true;
    }

    /**
     * Returns whether the header name (case-insensitive) is on the list of common headers
     * that are excluded from content inspection.
     */
    private boolean isCommonSafeHeader(String headerName) {
        return COMMON_SAFE_HEADERS.contains(headerName.toLowerCase(Locale.ROOT));
    }

    /**
     * Checks a single header value. User-Agent uses the scanner blocklist; every other header
     * uses {@link SecurityUtils#isSecureInput}.
     *
     * <p>Note: "user-agent" is also in the safe-header list, so validateHeaders never reaches
     * the User-Agent branch. It is kept for direct callers.
     */
    private boolean isSecureHeaderValue(String headerName, String headerValue) {
        if ("user-agent".equalsIgnoreCase(headerName)) {
            return !isMaliciousUserAgent(headerValue);
        }
        return securityUtils.isSecureInput(headerValue);
    }

    /**
     * Returns whether the URI starts with one of the public-endpoint prefixes.
     *
     * <p>NOTE(review): callers pass getRequestURI(), which includes the context path and is not
     * path-normalised. The prefixes therefore only match when the app runs at the root context,
     * and a URI such as "/api/public/../x" also matches. Confirm that an upstream firewall
     * rejects such paths.
     */
    private boolean isPublicEndpoint(String uri) {
        return PUBLIC_ENDPOINT_PREFIXES.stream().anyMatch(uri::startsWith);
    }

    /**
     * Extracts the token from an "Authorization: Bearer &lt;token&gt;" header.
     *
     * @return the token text, or {@code null} if the header is missing or not a Bearer header
     */
    private String getTokenFromRequest(HttpServletRequest request) {
        String bearerToken = request.getHeader("Authorization");
        if (bearerToken != null && bearerToken.startsWith(BEARER_PREFIX)) {
            return bearerToken.substring(BEARER_PREFIX.length());
        }
        return null;
    }

    /**
     * Writes a JSON error body ({@code success}, {@code error}, {@code message},
     * {@code timestamp}) with HTTP 400.
     *
     * <p>If the response is already committed, nothing is written (doing so would throw or
     * corrupt the response); the rejection is only logged.
     *
     * <p>NOTE(review): every rejection uses 400, including API abuse (arguably 429) and session
     * violations (arguably 401/403). It is left unchanged because clients may depend on it.
     */
    private void sendErrorResponse(HttpServletResponse response, String errorCode, String message)
            throws IOException {
        if (response.isCommitted()) {
            log.warn("Response already committed; cannot write security error response: error={}", errorCode);
            return;
        }

        response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
        response.setContentType("application/json;charset=UTF-8");
        response.setCharacterEncoding("UTF-8");

        Map<String, Object> errorResponse = new HashMap<>();
        errorResponse.put("success", false);
        errorResponse.put("error", errorCode);
        errorResponse.put("message", message);
        errorResponse.put("timestamp", LocalDateTime.now().toString());

        response.getWriter().write(objectMapper.writeValueAsString(errorResponse));
    }

    /**
     * Makes an untrusted value safe to embed in a log line.
     * CR/LF become '_' (no forged log entries), and the value is truncated to
     * {@value #MAX_LOGGED_VALUE_LENGTH} characters.
     */
    private static String sanitizeForLog(String value) {
        if (value == null) {
            return null;
        }
        String singleLine = value.replace('\r', '_').replace('\n', '_');
        if (singleLine.length() <= MAX_LOGGED_VALUE_LENGTH) {
            return singleLine;
        }
        return singleLine.substring(0, MAX_LOGGED_VALUE_LENGTH) + "...(truncated)";
    }
}
